
import os
import torch
import torchvision
import torchvision.transforms as transforms
from torchvision import models
import numpy as np
import matplotlib.pyplot as plt
from CIFAR10_GAN import *


# Report what the CUDA runtime can see, then choose the compute device.
cuda_available = torch.cuda.is_available()
print(cuda_available)  # True when PyTorch can use a CUDA GPU
print(torch.cuda.device_count())  # number of GPUs visible to PyTorch

device = torch.device("cuda") if cuda_available else torch.device("cpu")

# Move the GAN networks onto the chosen device
# (presumably defined by `from CIFAR10_GAN import *`).
generator = generator.to(device)
discriminator = discriminator.to(device)


# Work from the experiment directory so the relative data/checkpoint paths used below
# ('./data', '*.pth', '*.pt') resolve there.
new_path = "/users/eval/discriminative_approach"  # Replace with your desired path
os.chdir(new_path)
current_path = os.getcwd()
print("Current Path:", current_path)


##-----------------data loading-----------------##

# Preprocessing: ToTensor turns a PIL image into a float tensor, then Normalize applies
# (x - 0.5) / 0.5 to each of the three RGB channels.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)

# Options shared by both CIFAR-10 splits and by both loaders.
cifar_kwargs = dict(root='./data', download=True, transform=transform)
loader_kwargs = dict(batch_size=4, num_workers=2)

# NOTE(review): num_workers=2 at module top level with no `if __name__ == "__main__":`
# guard can fail on platforms where DataLoader workers use the "spawn" start method
# (e.g. Windows) — confirm the target platform.
trainset = torchvision.datasets.CIFAR10(train=True, **cifar_kwargs)
trainloader = torch.utils.data.DataLoader(trainset, shuffle=True, **loader_kwargs)

testset = torchvision.datasets.CIFAR10(train=False, **cifar_kwargs)
testloader = torch.utils.data.DataLoader(testset, shuffle=False, **loader_kwargs)

##-----------------dataset preview-----------------##

# Size of each split.
for split_name, split in (("training", trainset), ("testing", testset)):
    print(f"Number of {split_name} samples: {len(split)}")

# First training example, after the transform above has been applied.
image, label = trainset[0]
print(f"Image shape: {image.shape}")
print(f"Label: {label}")

# CIFAR-10 class names.
classes = trainset.classes
print(f"Class names: {classes}")

print(f"Data type of image: {type(image)}")
print(f"Shape of image tensor: {image.shape}")

# Draw one shuffled batch to check the shapes the loader produces.
dataiter = iter(trainloader)
images, labels = next(dataiter)
print(f"Batch image tensor shape: {images.shape}")
print(f"Batch label tensor shape: {labels.shape}")

##-----------------classifier model-----------------##

# Build a ResNet-18 with a 10-way head and restore the trained CIFAR-10 classifier.
# No ImageNet download is needed: the (strict, by default) load_state_dict below must
# overwrite every parameter and buffer, so pretrained initial weights would be discarded.
model = models.resnet18()
# Modify the final layer to match the number of classes in CIFAR-10
model.fc = torch.nn.Linear(model.fc.in_features, 10)

# map_location lets a checkpoint that was saved from a GPU load on a CPU-only machine.
state_dict = torch.load('cifar10_classifier.pth', map_location=device)
model.load_state_dict(state_dict)
model = model.to(device)
# The model is used for inference only. Eval mode makes BatchNorm use its stored running
# statistics instead of per-batch statistics, and stops labeling from mutating them.
model.eval()


##-----------------Generate images update-----------------##

# Generator checkpoints (the 100/300/500 suffix presumably marks the training stage —
# inferred from the file names). map_location lets GPU-saved checkpoints load on CPU.
Parameter100 = torch.load('state_dict_gene_full_100.pt', map_location=device)
Parameter300 = torch.load('state_dict_gene_full_300.pt', map_location=device)
Parameter500 = torch.load('state_dict_gene_full_500.pt', map_location=device)

# Fixed latent codes. Every checkpoint is run on the same z vectors, so the three
# generated datasets differ only in the generator weights.
seed = 13
torch.manual_seed(seed)
z_train = torch.randn(100000, 100).to(device)
z_test = torch.randn(100000, 100).to(device)

# Run the latents in chunks to bound the memory of each forward pass.
batch_size = 2000
z_train_batches = torch.split(z_train, batch_size)
z_test_batches = torch.split(z_test, batch_size)
num_batches = len(z_train_batches)


def _generate_and_classify(z):
    """Generate images from latent batch ``z`` and label them with ``model``.

    Returns ``(images, labels)``. ``images`` is ``(generator(z) + 1) / 2``, which
    presumably maps a tanh output from [-1, 1] to [0, 1] (verify against CIFAR10_GAN).
    ``labels`` holds the classifier's highest-scoring class index for each image.
    Intended to be called inside ``torch.no_grad()``.
    """
    images = (generator(z) + 1) / 2
    # NOTE(review): the CIFAR-10 loaders in this script normalize images to [-1, 1], but
    # these images are passed to the classifier in [0, 1]. Confirm this matches the
    # preprocessing used when cifar10_classifier.pth was trained.
    _, labels = torch.max(model(images), 1)
    return images, labels


def _run_gan_stage(state_dict, tag):
    """Load one generator checkpoint, build labeled train/test sets, and save them.

    Args:
        state_dict: generator weights to load.
        tag: checkpoint identifier used in output file names and log lines.

    Returns:
        ``(images_train, images_test, labels_train, labels_test)``, each concatenated
        over all latent batches along dim 0.

    Side effects: writes ``GAN_{tag}_train.pth`` and ``GAN_{tag}_test.pth`` to the
    current directory. Each file holds ``{'image': ..., 'label': ...}``.
    """
    # strict=False is kept from the original. A silent key mismatch would leave the
    # generator (partly) at its previous/initial weights, so report any mismatch.
    result = generator.load_state_dict(state_dict, strict=False)
    missing = getattr(result, "missing_keys", [])
    unexpected = getattr(result, "unexpected_keys", [])
    if missing or unexpected:
        print(f"WARNING: GAN {tag} checkpoint mismatch; "
              f"missing keys: {missing}, unexpected keys: {unexpected}")
    generator.eval()
    # Labels must come from the classifier in eval mode (BatchNorm running stats).
    # This is a no-op if it is already in eval mode.
    model.eval()

    images_train, images_test, labels_train, labels_test = [], [], [], []
    # Pure inference: no_grad avoids building autograd graphs for every batch.
    with torch.no_grad():
        for z_tr, z_te in zip(z_train_batches, z_test_batches):
            x_tr, y_tr = _generate_and_classify(z_tr)
            x_te, y_te = _generate_and_classify(z_te)
            images_train.append(x_tr)
            images_test.append(x_te)
            labels_train.append(y_tr)
            labels_test.append(y_te)

    images_train = torch.cat(images_train, dim=0)
    images_test = torch.cat(images_test, dim=0)
    labels_train = torch.cat(labels_train, dim=0)
    labels_test = torch.cat(labels_test, dim=0)
    for tensor in (images_train, images_test, labels_train, labels_test):
        print(tensor.shape)
    torch.save({'image': images_train, 'label': labels_train}, f'GAN_{tag}_train.pth')
    torch.save({'image': images_test, 'label': labels_test}, f'GAN_{tag}_test.pth')
    print(f"{tag} generation done")
    return images_train, images_test, labels_train, labels_test


# GAN 100 images
Image_100_train, Image_100_test, Y_100_train, Y_100_test = _run_gan_stage(Parameter100, 100)

# GAN 300 images
Image_300_train, Image_300_test, Y_300_train, Y_300_test = _run_gan_stage(Parameter300, 300)

# GAN 500 images
Image_500_train, Image_500_test, Y_500_train, Y_500_test = _run_gan_stage(Parameter500, 500)

